"""MSA pre-parsing script for AF3 dataset."""

import json
import multiprocessing as mp
from pathlib import Path

import boto3
import click
import numpy as np
from tqdm import tqdm

from openfold3.core.data.io.s3 import download_file_from_s3
from openfold3.core.data.io.sequence.msa import parse_msas_direct, standardize_filepaths

# Per-process boto3 Session used for S3 downloads; populated in each pool
# worker by _init_worker (passed as the multiprocessing.Pool initializer).
_worker_session = None


def _init_worker(profile_name: str = "openfold") -> None:
    """Initialize the module-level boto3 session in each worker process.

    Args:
        profile_name: Name of the AWS credentials profile to use.
    """
    # A session cannot be pickled across process boundaries, so each worker
    # creates its own and stores it in the module-level global.
    global _worker_session
    _worker_session = boto3.Session(profile_name=profile_name)


# TODO: merge with existing MSA pre-parsing code
@click.command()
@click.option(
    "--alignment_array_directory",
    help="Output directory to which the per-chain MSA npz files are to be saved.",
    type=click.Path(
        exists=False,
        file_okay=False,
        dir_okay=True,
        path_type=Path,
    ),
)
@click.option(
    "--dataset_cache_file",
    type=str,
    help="path to dataset cache json file, generated by "
    "make_monomer_distillation_set_datacache.py",
)
@click.option(
    "--s3_config",
    type=str,
    help="Path to the s3 client config file.",
)
@click.option(
    "--num_workers",
    default=1,
    type=int,
    help=(
        "Number of workers to parallelize the template cache computation and filtering"
        " over."
    ),
)
@click.option(
    "--max_seq_config",
    type=str,
    default=None,
    help="json string mapping alignment db to max seq counts",
)
def main(
    alignment_array_directory: Path,
    dataset_cache_file: str,
    num_workers: int,
    s3_config: str,
    max_seq_config: str,
):
    """Preparse multiple sequence alignments for AF3 dataset.

    For each representative chain listed in the dataset cache, downloads its
    alignment file from S3, parses it, and writes a compressed npz of the
    pre-parsed MSA arrays to ``alignment_array_directory``.
    """
    with open(dataset_cache_file) as f:
        dataset_cache = json.load(f)
    with open(s3_config) as f:
        s3_config = json.load(f)

    # --max_seq_config defaults to None; json.loads(None) would raise
    # TypeError, so only decode when the option was actually given.
    # (Assumes parse_msas_direct accepts max_seq_counts=None — TODO confirm.)
    max_seq_counts = json.loads(max_seq_config) if max_seq_config is not None else None

    alignment_array_directory.mkdir(parents=True, exist_ok=True)
    rep_chain_ids = list(dataset_cache["structure_data"].keys())

    # Create the per-chain MSA pre-parser (one npz per query chain).
    wrapped_msa_preparser = _MsaPreparser(
        alignment_array_directory, max_seq_counts, s3_config
    )
    if num_workers > 1:
        with mp.Pool(
            num_workers, initializer=_init_worker, initargs=(s3_config["profile"],)
        ) as pool:
            for _ in tqdm(
                pool.imap_unordered(
                    wrapped_msa_preparser,
                    rep_chain_ids,
                    chunksize=1,
                ),
                total=len(rep_chain_ids),
                desc="Pre-parsing MSAs",
            ):
                pass
    else:
        # Bug fix: the sequential path previously skipped the pool initializer,
        # leaving _worker_session as None for every S3 download.
        _init_worker(s3_config["profile"])
        for chain in tqdm(rep_chain_ids, desc="Pre-parsing MSAs"):
            wrapped_msa_preparser(chain)


def preparse_msas(
    alignments_directory: Path,
    alignment_array_directory: Path,
    max_seq_counts: dict[str, int],
    rep_pdb_chain_id: str,
) -> None:
    """Parse one chain's raw MSA files and save them as a compressed npz.

    Locates the chain's alignment folder under ``alignments_directory``,
    parses the alignment files (limited per database by ``max_seq_counts``),
    and writes the dict-converted MSAs to
    ``alignment_array_directory / f"{rep_pdb_chain_id}.npz"``.
    """
    chain_dir = alignments_directory / Path(rep_pdb_chain_id)
    parsed_msas = parse_msas_direct(
        file_list=standardize_filepaths(chain_dir),
        max_seq_counts=max_seq_counts,
    )

    alignment_array_directory.mkdir(parents=True, exist_ok=True)
    arrays = {name: msa.to_dict() for name, msa in parsed_msas.items()}
    np.savez_compressed(
        alignment_array_directory / Path(f"{rep_pdb_chain_id}.npz"), **arrays
    )


class _MsaPreparser:
    def __init__(
        self,
        alignment_array_directory: Path,
        max_seq_counts: dict[str, int],
        s3_config: dict[str, str],
    ) -> None:
        """Wrapper class for pre-parsing a directory of raw MSA files.

        This wrapper around `preparse_msas` is needed for multiprocessing, so that we
        can pass the constant arguments in a convenient way catch any errors that would
        crash the workers, and change the function call to accept a single Iterable.

        The wrapper is written as a class object because multiprocessing doesn't support
        decorator-like nested functions.

        Attributes:
            alignments_directory:
                Directory containing per-chain folders with multiple sequence
                alignments.
            alignment_array_directory:
                Output directory to which the per-chain MSA npz files are to be saved.

        """
        self.alignment_array_directory = alignment_array_directory
        self.max_seq_counts = max_seq_counts
        self.s3_config = s3_config

    def __call__(self, rep_pdb_chain_id: str) -> None:
        tmp_dir = Path(f"/tmp/alignments/{rep_pdb_chain_id}")
        tmp_dir.mkdir(parents=True, exist_ok=True)
        global _worker_session
        try:
            download_file_from_s3(
                bucket=self.s3_config["bucket"],
                prefix=f"{self.s3_config['prefix']}/{rep_pdb_chain_id}",
                filename="concat_cfdb_uniref100_filtered.a3m",
                outfile=str(tmp_dir / "concat_cfdb_uniref100_filtered.a3m"),
                session=_worker_session,
            )

            preparse_msas(
                Path("/tmp/alignments"),
                self.alignment_array_directory,
                self.max_seq_counts,
                rep_pdb_chain_id,
            )
            (tmp_dir / "concat_cfdb_uniref100_filtered.a3m").unlink()
            tmp_dir.rmdir()
        except Exception as e:
            print(f"Failed to preparse MSAs for chain {rep_pdb_chain_id}:\n{e}\n")


# Script entry point: delegates argument parsing to the click command.
if __name__ == "__main__":
    main()
